## Replication Code for
## "Higher Education and Cultural Liberalism"
## Apfeld, Coman, Gerring, and Jessee
## Journal of Experimental Political Science

## NOTE: This code provides a modified version of the rdplotdensity
## function provided by the Cattaneo et al. (2020) rddensity package.
## That package should be installed prior to sourcing this code
## so that all dependencies of that package (notably lpdensity)
## are available.

## MODIFICATIONS OF ORIGINAL:
## 1. The internal lpdensity.plot_modified modifies lpdensity.plot. It
##    allows the user to plot two separate density estimates at the same time.
## 2. rdplotdensity_two_lines modifies rdplotdensity in two ways. First, it
##    calculates two separate density lines to pass to lpdensity.plot_modified.
##    Second, the function returns a ggplot object instead of printing the
##    figure directly. This allows the user to save the object for
##    additional aesthetic changes or to combine multiple subplots into a
##    single figure.

## Package required (version number used in paper in parentheses):
##     rddensity (2.1)


## Plot one or more density estimates on a single ggplot, drawing both the
## f_p and the f_q estimate of every series that is shown as a line.
##
## Arguments:
##   ...           One or more fitted objects. Each must provide an `Estimate`
##                 matrix with columns grid, f_p, f_q, se_p, se_q and an `opt`
##                 list with elements p and q (the wrapper in this file passes
##                 results of lpdensity::lpdensity).
##   alpha         Significance level(s), each strictly between 0 and 1, used
##                 for the confidence intervals. Default 0.05.
##   type          Per series: "line", "points" or "both". Default "line".
##   CItype        Per series: "region", "line", "ebar", "all" or "none".
##                 Default "region".
##   title, xlabel, ylabel
##                 Plot title and axis labels.
##   lty           Accepted so the signature matches lpdensity.plot, but it
##                 currently has no effect on the figure (the line types come
##                 from the fixed "solid"/"dotted" mapping in the line layers).
##   lwd           Line width(s). Default 0.5.
##   lcol          Line colour(s). Default 1:nfig.
##   pty, pwd      Point shape(s) and size(s). Default 1.
##   pcol          Point colour(s). Default: same as lcol.
##   CIshade       Transparency of the confidence intervals. Default 0.2.
##   CIcol         Colour of the confidence intervals. Default: same as lcol.
##   legendTitle   Legend title; only the first element is used. Default "".
##   legendGroups  Label per series. Default "Series 1", "Series 2", ...
##
## All per-series options are recycled to the number of series.
##
## Returns a ggplot object. The y axis is limited to [0, 0.25].
##
## Stops with an error when no series is supplied or when alpha, type or
## CItype contain an unsupported value.
lpdensity.plot_modified <- function(..., alpha=NULL, type=NULL, CItype=NULL,
                                    title="", xlabel="", ylabel="",
                                    lty=NULL, lwd=NULL, lcol=NULL, pty=NULL, pwd=NULL, pcol=NULL,
                                    CIshade=NULL, CIcol=NULL, legendTitle=NULL, legendGroups=NULL) {

  ########################################
  # check how many series are passed in
  ########################################
  series <- list(...)
  nfig <- length(series)
  if (nfig == 0) stop("Nothing to plot.\n")

  # Recycle a per-series option to one value per series, or use its default
  # when the caller supplied nothing. `default` is only evaluated if needed.
  recycle <- function(value, default) {
    if (length(value) == 0) default else rep(value, length.out = nfig)
  }

  ########################################
  # error handling
  ########################################
  if (length(alpha) > 0 && !all(alpha > 0 & alpha < 1)) {
    stop("Significance level incorrectly specified.\n")
  }
  alpha <- recycle(alpha, rep(0.05, nfig))

  if (length(type) > 0 && !all(type %in% c("line", "points", "both"))) {
    stop("Plotting type incorrectly specified.\n")
  }
  type <- recycle(type, rep("line", nfig))

  if (length(CItype) > 0 &&
      !all(CItype %in% c("region", "line", "ebar", "all", "none"))) {
    stop("Confidence interval type incorrectly specified.\n")
  }
  CItype <- recycle(CItype, rep("region", nfig))

  # line width and colour (lcol must be resolved before pcol and CIcol,
  # which default to it)
  lwd  <- recycle(lwd, rep(0.5, nfig))
  lcol <- recycle(lcol, seq_len(nfig))

  # point style, point width, point colour
  pty  <- recycle(pty, rep(1, nfig))
  pwd  <- recycle(pwd, rep(1, nfig))
  pcol <- recycle(pcol, lcol)

  # CI shade, CI colour
  CIshade <- recycle(CIshade, rep(0.2, nfig))
  CIcol   <- recycle(CIcol, lcol)

  # legend: exactly one label per series. (Collecting the per-row label
  # column here instead would make the sort index computed further down refer
  # to rows rather than series.)
  legendTitle <- if (length(legendTitle) == 0) "" else legendTitle[1]
  legendGroups <- recycle(legendGroups, paste("Series", seq_len(nfig), sep=" "))

  ########################################
  # initializing plot
  ########################################
  temp_plot <- ggplot2::ggplot() + ggplot2::theme_bw()

  # Placeholders for the column names referenced inside aes() below
  # (presumably kept from the upstream code to silence static-check notes);
  # the layers take these columns from their `data` argument.
  CI_l <- CI_r <- f_p <- f_q <- grid <- Sname <- NULL

  ########################################
  # looping over input models
  ########################################
  # colour / point shape actually used by each series, in input order
  col_all <- pty_all <- c()
  for (i in seq_len(nfig)) {
    fit <- series[[i]]
    data_x <- data.frame(fit$Estimate[, c("grid", "f_p", "f_q", "se_p", "se_q")])

    # When both polynomial orders coincide, reuse the f_p estimate and its
    # standard error for f_q.
    if (fit$opt$q == fit$opt$p) {
      data_x$f_q <- data_x$f_p
      data_x$se_q <- data_x$se_p
    }

    # Normal-approximation interval centred on f_q.
    z_val <- stats::qnorm(1 - alpha[i] / 2)
    data_x$CI_l <- data_x$f_q - z_val * data_x$se_q
    data_x$CI_r <- data_x$f_q + z_val * data_x$se_q

    data_x$Sname <- legendGroups[i]

    ########################################
    # add CI regions to the plot
    if (CItype[i] %in% c("region", "all")) {
      temp_plot <- temp_plot +
        ggplot2::geom_ribbon(data=data_x, ggplot2::aes(x=grid, ymin=CI_l, ymax=CI_r),
                             alpha=CIshade[i], fill=CIcol[i])
    }

    ########################################
    # add CI lines to the plot
    if (CItype[i] %in% c("line", "all")) {
      temp_plot <- temp_plot +
        ggplot2::geom_line(data=data_x, ggplot2::aes(x=grid, y=CI_l),
                           linetype=2, alpha=CIshade[i], col=CIcol[i]) +
        ggplot2::geom_line(data=data_x, ggplot2::aes(x=grid, y=CI_r),
                           linetype=2, alpha=CIshade[i], col=CIcol[i])
    }

    ########################################
    # add error bars to the plot
    if (CItype[i] %in% c("ebar", "all")) {
      temp_plot <- temp_plot +
        ggplot2::geom_errorbar(data=data_x, ggplot2::aes(x=grid, ymin=CI_l, ymax=CI_r),
                               alpha=CIshade[i], col=CIcol[i], linetype=1)
    }

    ########################################
    # add lines to the plot: one for f_p and one for f_q
    #
    # NOTE(review): "solid" and "dotted" sit inside aes(), so ggplot2 treats
    # them as category labels on a discrete linetype scale rather than as
    # literal line types; the styles actually drawn are whatever that scale
    # assigns to the two labels. Confirm the rendered styles are the intended
    # ones. Left as is so the published figure stays reproducible.
    # `size` (rather than `linewidth`) is kept for the same reason.
    if (type[i] %in% c("line", "both")) {
      temp_plot <- temp_plot +
        ggplot2::geom_line(data=data_x,
                           ggplot2::aes(x=grid, y=f_p, colour=Sname, linetype="solid"),
                           size=lwd[i])
      temp_plot <- temp_plot +
        ggplot2::geom_line(data=data_x,
                           ggplot2::aes(x=grid, y=f_q, colour=Sname, linetype="dotted"),
                           size=lwd[i])
    }

    ########################################
    # add points to the plot
    if (type[i] %in% c("points", "both")) {
      temp_plot <- temp_plot +
        ggplot2::geom_point(data=data_x,
                            ggplot2::aes(x=grid, y=f_p, colour=Sname, shape=Sname),
                            size=pwd[i])
    }

    # Series drawn with a line take the line colour, points-only series take
    # the point colour; series without points get no shape.
    has_line   <- type[i] %in% c("line", "both")
    has_points <- type[i] %in% c("points", "both")
    col_all <- c(col_all, if (has_line) lcol[i] else pcol[i])
    pty_all <- c(pty_all, if (has_points) pty[i] else NA)
  }

  ########################################
  # change color and point shape back, and customize legend
  ########################################
  # Manual scale values are handed over in the sorted order of the series
  # labels, so reorder the per-series values accordingly.
  index <- sort.int(legendGroups, index.return=TRUE)$ix
  temp_plot <- temp_plot +
    ggplot2::scale_color_manual(values = col_all[index]) +
    ggplot2::scale_shape_manual(values = pty_all[index]) +
    ggplot2::guides(colour=ggplot2::guide_legend(title=legendTitle)) +
    ggplot2::guides(linetype=ggplot2::guide_legend(title=legendTitle)) +
    ggplot2::guides(shape=ggplot2::guide_legend(title=legendTitle)) +
    # Fixed y range; ggplot2's default out-of-bounds handling applies to
    # anything outside it.
    ggplot2::scale_y_continuous(limits = c(0, 0.25))

  ########################################
  # add title, x and y labs, and return the plot
  ########################################
  temp_plot + ggplot2::labs(x=xlabel, y=ylabel) + ggplot2::ggtitle(title)
}

# And now a modified function to plot both of the lines
## Estimate the density separately on each side of the cutoff and return the
## combined figure as a ggplot object (nothing is printed).
##
## Arguments:
##   rdd        Fitted object supplying opt$c (cutoff), opt$p, opt$q,
##              opt$kernel and the bandwidths h$left / h$right (the result of
##              rddensity, per the header of this file).
##   X          Numeric data vector the densities are estimated from.
##              Missing values are dropped.
##   plotRange  Length-2 vector with the left and right end of the plotted
##              range; it must contain the cutoff strictly inside. Default:
##              three bandwidths on either side of the cutoff, truncated to
##              the range of X.
##   plotN      Number of grid points left and right of the cutoff (one value
##              is used for both sides); each must exceed 1. Default 10.
##   plotGrid   "es" for evenly spaced grid points, "qs" for grid points at
##              quantiles of X. Only the first element is used.
##   alpha, type, CItype, title, xlabel, ylabel, lty, lwd, lcol, pty, pwd,
##   pcol, CIshade, CIcol, legendTitle, legendGroups
##              Passed unchanged to lpdensity.plot_modified.
##
## Returns the ggplot object built by lpdensity.plot_modified with the legend
## switched off.
##
## Stops with an error when plotRange, plotN or plotGrid are invalid.
rdplotdensity_two_lines <- function(rdd, X, plotRange = NULL, plotN = 10, plotGrid = c("es", "qs"),
                                    alpha = 0.05,
                                    type = NULL, CItype = NULL,
                                    title = "", xlabel = "", ylabel = "",
                                    lty = NULL, lwd = NULL, lcol = NULL,
                                    pty = NULL, pwd = NULL, pcol = NULL,
                                    CIshade = NULL, CIcol = NULL,
                                    legendTitle = NULL, legendGroups = NULL){

  # obtain options from rddensity result
  # (the cutoff is kept in `cutoff` so that base::c is not shadowed)
  cutoff  <- rdd$opt$c
  p       <- rdd$opt$p
  q       <- rdd$opt$q
  hl      <- rdd$h$left
  hr      <- rdd$h$right
  kernel  <- rdd$opt$kernel

  # Missing values would turn the range, the scaling factors and the grids
  # below into NA, so drop them first.
  X <- X[!is.na(X)]

  # check grid specifications
  if (length(plotRange) == 0) {
    plotRange <- c(max(min(X), cutoff - 3 * hl), min(max(X), cutoff + 3 * hr))
  } else if (length(plotRange) != 2) {
    stop("Plot range incorrectly specified.\n")
  } else if (plotRange[1] >= cutoff || plotRange[2] <= cutoff) {
    stop("Plot range incorrectly specified.\n")
  }

  if (length(plotN) == 0) {
    plotN <- c(10, 10)
  } else if (length(plotN) == 1) {
    plotN <- c(plotN, plotN)
  } else if (length(plotN) > 2) {
    stop("Number of grid points incorrectly specified.\n")
  }
  if (plotN[1] <= 1 || plotN[2] <= 1) {
    stop("Number of grid points incorrectly specified.\n")
  }

  plotGrid <- if (length(plotGrid) == 0) "es" else plotGrid[1]
  if (!plotGrid %in% c("es", "qs")) {
    stop("Grid specification invalid.\n")
  }

  # Share of observations on each side of the cutoff; handed to lpdensity as
  # `scale` for the corresponding one-sided estimate.
  n_minus_one <- length(X) - 1
  scalel <- (sum(X <= cutoff) - 1) / n_minus_one
  scaler <- (sum(X >= cutoff) - 1) / n_minus_one

  if (plotGrid == "es") {
    # evenly spaced grid points
    gridl <- seq(plotRange[1], cutoff, length.out=plotN[1])
    gridr <- seq(cutoff, plotRange[2], length.out=plotN[2])
  } else {
    # grid points at evenly spaced quantile levels of X
    at_cutoff <- mean(X <= cutoff)
    gridl <- quantile(X, seq(mean(X <= plotRange[1]), at_cutoff, length.out=plotN[1]))
    gridr <- quantile(X, seq(at_cutoff, mean(X <= plotRange[2]), length.out=plotN[2]))
  }
  # Both grids meet exactly at the cutoff.
  gridl[plotN[1]] <- cutoff
  gridr[1] <- cutoff

  # call lpdensity on each side of the cutoff
  Estl <- lpdensity::lpdensity(data=X[X<=cutoff], grid=gridl, bw=hl, p=p, q=q, v=1, kernel=kernel, scale=scalel)
  Estr <- lpdensity::lpdensity(data=X[X>=cutoff], grid=gridr, bw=hr, p=p, q=q, v=1, kernel=kernel, scale=scaler)

  # build the figure; the legend is suppressed
  Estplot <- lpdensity.plot_modified(Estl,
                                     Estr,
                                     alpha = alpha,
                                     type = type,
                                     CItype = CItype,
                                     title = title,
                                     xlabel = xlabel,
                                     ylabel = ylabel,
                                     lty = lty,
                                     lwd = lwd,
                                     lcol = lcol,
                                     pty = pty,
                                     pwd = pwd,
                                     pcol = pcol,
                                     CIshade = CIshade,
                                     CIcol = CIcol,
                                     legendTitle = legendTitle,
                                     legendGroups = legendGroups) +
    ggplot2::theme(legend.position = "none")

  # The plot is returned, not printed, so callers can restyle or combine it.
  Estplot
}
